import carla
import numpy as np
from typing import List

class LineSegment:
    """
    A 2D line segment between two points
    Only the ``x`` and ``y`` attributes of the end points are read.
    """

    def __init__(self, start: carla.Location, end: carla.Location):
        """
        :param start: first end point of the segment
        :param end: second end point of the segment
        """
        self.start = start
        self.end = end

    def intersect(self, other):
        """
        Check if two line segments intersect in the x-y plane
        :param other: another line segment
        :return: True if the segments cross or touch (end points included),
                 False otherwise. Parallel segments, including collinear
                 overlapping ones, are reported as not intersecting.
        """
        x1, y1 = self.start.x, self.start.y
        x2, y2 = self.end.x, self.end.y
        x3, y3 = other.start.x, other.start.y
        x4, y4 = other.end.x, other.end.y
        denom = (x1-x2)*(y3-y4) - (y1-y2)*(x3-x4)
        if denom == 0:
            # Parallel (or zero-length) segments: no unique crossing point
            return False
        # t locates the crossing point along self (0 = start, 1 = end)
        t = ((x1-x3)*(y3-y4) - (y1-y3)*(x3-x4)) / denom
        if t < 0 or t > 1:
            return False
        # u locates the crossing point along other (0 = start, 1 = end)
        u = -((x1-x2)*(y1-y3) - (y1-y2)*(x1-x3)) / denom
        return not (u < 0 or u > 1)

    def contain(self, point, tol=1e-1):
        """
        Check if a point coincides with one of the two end points of the segment
        Interior points of the segment are NOT reported as contained; only a
        point whose x and y are both within ``tol`` of the same end point is.
        :param point: a point with ``x`` and ``y`` attributes
        :param tol: per-axis tolerance (exclusive) for matching an end point
        :return: True if the point matches an end point, False otherwise
        """
        for end_point in (self.start, self.end):
            if abs(end_point.x - point.x) < tol and abs(end_point.y - point.y) < tol:
                return True
        return False

def cartesian_to_polar(start: carla.Location, end: carla.Location):
    """
    Express the vector going from ``start`` to ``end`` in polar form
    Only the ``x`` and ``y`` attributes of the two points are used.
    :param start: origin of the vector
    :param end: tip of the vector
    :return: (r, theta) where r is the planar length of the vector and theta
             its direction as returned by np.arctan2 (radians)
    """
    delta_x, delta_y = end.x - start.x, end.y - start.y
    length = np.sqrt(delta_x**2 + delta_y**2)
    heading = np.arctan2(delta_y, delta_x)
    return length, heading

def merge_angle_intervals(angle_intervals: List[List[float]],
                          obj_centers: List,
                          obj_segments: List[List[LineSegment]]
                          ):
    """
    Wrap every angle interval into an occlusion record
    Despite the name, overlapping intervals are NOT merged: exactly one record
    is produced per input interval, in input order, so that record ``k``
    always describes object ``k`` (check_2d_visibility relies on this to skip
    an object's own occlusion).
    :param angle_intervals: a list of angle intervals, each interval is a list of two floats
    :param obj_centers: one object per interval; its ``location`` attribute is
                        stored as the centre of the occlusion
    :param obj_segments: one list of line segments per interval
    :return: a list of dicts with the keys "ids" (single-element list holding
             the interval index), "angle_interval", "segments" and "centers";
             an empty list if there are no intervals
    """
    if not angle_intervals:
        return []
    return [{"ids": [idx],
             "angle_interval": interval,
             "segments": obj_segments[idx],
             "centers": obj_centers[idx].location}
            for idx, interval in enumerate(angle_intervals)]

def check_num_occluded_vertices(ego_location: carla.Location,
                                obj_vertices: List[carla.Location],
                                occlusions: List[dict],
                                accuracy: float = 2e-2):
    """
    Count how many of the given points are hidden behind an occlusion
    A point counts as occluded when its direction seen from the ego lies
    inside the angle interval of an occlusion (widened by ``accuracy``) and
    the centre of that occlusion is closer to the ego than the point itself.
    Each point is counted at most once.
    :param ego_location: location of the observer
    :param obj_vertices: a list of points to test
    :param occlusions: a list of occlusion dicts; the keys "angle_interval"
                       and "centers" are read
    :param accuracy: angular slack added to the interval bounds
    :return: the number of occluded points
    """
    num_occluded_vertices = 0
    for vertice in obj_vertices:
        angle = cartesian_to_polar(ego_location, vertice)[1]
        for occlusion in occlusions:
            interval = occlusion["angle_interval"]
            if interval[1] > np.pi and angle < 0:
                # An upper bound above pi marks an interval that wraps past
                # the +/-pi cut: shift the bound back by 2*pi so that it can
                # be compared with a negative angle
                in_interval = angle < interval[1] - 2*np.pi + accuracy
            else:
                in_interval = interval[0] - accuracy <= angle <= interval[1] + accuracy
            if not in_interval:
                continue
            # Only an occluder that is closer than the point can hide it
            if ego_location.distance(occlusion["centers"]) < ego_location.distance(vertice):
                num_occluded_vertices += 1
                break
    return num_occluded_vertices

def process_vertices(vertices):
    """
    Reduce the vertices of an object to the four used for its 2D outline
    :param vertices: an indexable collection of vertices, there should be 8 of them
    :return: the list [vertices[0], vertices[2], vertices[6], vertices[4]]
    """
    # TODO: check correctness
    # NOTE(review): indices 0, 2, 6, 4 presumably walk one face of the box in
    # perimeter order - confirm against the vertex order of the caller
    outline_indices = (0, 2, 6, 4)
    return [vertices[idx] for idx in outline_indices]

def check_center_visibility(ego_location: carla.Location,
                            obj_center: carla.Location,
                            occlusion: dict,
                            accuracy: float = 2e-2):
    """
    Check if the center of an object is hidden behind a single occlusion
    NOTE: despite the name, the return value is True when the center is
    OCCLUDED, i.e. its direction seen from the ego lies inside the angle
    interval of the occlusion (widened by ``accuracy``) and the centre of the
    occlusion is closer to the ego than the object center.
    :param ego_location: location of the observer
    :param obj_center: center of the object to test
    :param occlusion: an occlusion dict; the keys "angle_interval" and
                      "centers" are read
    :param accuracy: angular slack added to the interval bounds
    :return: True if the center is occluded, False otherwise
    """
    angle = cartesian_to_polar(ego_location, obj_center)[1]
    interval = occlusion["angle_interval"]
    if interval[1] > np.pi and angle < 0:
        # An upper bound above pi marks an interval that wraps past the
        # +/-pi cut: shift the bound back by 2*pi before comparing it with a
        # negative angle
        check_intersect = angle < interval[1] - 2*np.pi + accuracy
    else:
        check_intersect = interval[0] - accuracy <= angle <= interval[1] + accuracy
    if not check_intersect:
        return False
    # Only an occluder that is closer than the object can hide it
    return bool(ego_location.distance(occlusion["centers"]) < ego_location.distance(obj_center))

def check_2d_visibility(ego_location: carla.Location,
                        obj_list: List,
                        visible_range: float = 100):
    """
    This function checks if the object is visible from the center of the ego vehicle
    Assumptions:
        (1) 2D
        (2) 360 degree view
        (3) Rectangular bounding box
        (4) Non-intersecting objects, none of them containing the ego
        (5) If the center point is visible, the whole object is visible
    Algorithm:
        (0) Prepare data, reduce every object to 4 corners and 4 edge segments
        (1) Scan all objects, find the angle interval each one covers as seen
            from the ego
        (2) Wrap every interval into an occlusion record (one per object, no
            merging, so record j belongs to object j)
        (3) An object is invisible when the direction to its center falls in
            the interval of any OTHER object whose center is closer to the ego
        O(n^2) complexity
    Side effects: ``ego_location.z`` and ``obj.location.z`` of every object
    are set to 0 in place.
    :param ego_location: ego vehicle location, in the form of carla.Location
    :param obj_list: list of objects, in the form of carla.BoundingBox
    :visible_range: visible range, in the form of float (currently unused)
    :return: a boolean numpy array with one entry per object, True where the
             object is visible; an empty list if obj_list is empty
    """
    if not obj_list: return []
    visible_obj_mask = np.ones(len(obj_list), dtype=bool)
    # (0) Prepare object list
    obj_vertices = []
    obj_segments = []
    ego_location.z = 0
    for obj in obj_list:
        obj.location.z = 0
        # NOTE(review): the corners are assumed to be expressed in the same
        # frame as ego_location - confirm for the bounding boxes passed in
        vertices = process_vertices(obj.get_local_vertices())
        obj_vertices.append(vertices)
        obj_segments.append([LineSegment(vertices[i], vertices[(i+1) % 4]) for i in range(4)])
    # (1) Scan all objects, find the occlusion angles
    occlusion_angle_intervals = []
    for vertices in obj_vertices:
        angles = sorted(cartesian_to_polar(ego_location, vertice)[1] for vertice in vertices)
        if angles[-1] - angles[0] > np.pi:
            # Convex objects can not have a angle interval larger than 180
            # degree, so a larger span means the object straddles the +/-pi
            # cut. Shift the negative angles by 2*pi to make the interval
            # contiguous; this works for any split of the corners between
            # the two sides of the cut, not only two and two.
            unwrapped = [a + 2*np.pi if a < 0 else a for a in angles]
            occlusion_angle_intervals.append([min(unwrapped), max(unwrapped)])
        else:
            occlusion_angle_intervals.append([angles[0], angles[-1]])
    # (2) Build one occlusion record per object (index-aligned with obj_list)
    occlusions = merge_angle_intervals(occlusion_angle_intervals, obj_list, obj_segments)
    # (3) check if the center of each object is hidden by any other object
    for i, obj in enumerate(obj_list):
        center = obj.location
        for j, occlusion in enumerate(occlusions):
            # An object can not occlude itself
            if i == j: continue
            if check_num_occluded_vertices(ego_location, [center], [occlusion]) >= 1:
                visible_obj_mask[i] = False
                break
    return visible_obj_mask
